from sto import StoDataset, NeuralNetwork
from torch.utils.data import random_split, DataLoader
from torch import nn, optim, no_grad

def train(dataloader, model, loss_fn, optimizer, log_interval=6):
    """Run one optimisation pass over ``dataloader`` and print progress.

    For every batch the model output is scored with ``loss_fn``, the loss is
    back-propagated, ``optimizer`` takes one step and its gradients are then
    cleared.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()``
            and exposes a sized ``dataset`` attribute.
        model: Callable model; put into training mode via ``model.train()``.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports
            ``backward()`` and ``item()``.
        optimizer: Object providing ``step()`` and ``zero_grad()``.
        log_interval: Print the loss every ``log_interval`` batches; the
            final batch is always printed. Defaults to 6.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    size = len(dataloader.dataset)
    last_batch = len(dataloader) - 1
    seen = 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Count samples as they arrive instead of multiplying by
        # ``dataloader.batch_size``: that attribute is None when the loader
        # was built with a ``batch_sampler``.
        seen += len(X)
        if batch % log_interval == 0 or batch == last_batch:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print the average batch loss.

    The model is put into evaluation mode and every batch is scored with
    ``loss_fn`` under ``no_grad()``. The printed figure is the sum of the
    per-batch loss values divided by the number of batches.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()``.
        model: Callable model; put into evaluation mode via ``model.eval()``.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports
            ``item()``.
    """
    model.eval()
    num_batches = len(dataloader)
    test_loss = 0.0
    with no_grad():
        for X, y in dataloader:
            test_loss += loss_fn(model(X), y).item()
    # An empty loader has nothing to average; report NaN instead of raising
    # ZeroDivisionError.
    test_loss = test_loss / num_batches if num_batches else float("nan")
    print(f"Test Avg loss: {test_loss:>8f} \n")


# The name matches pytest's default ``test*`` function pattern; opt out so a
# test module that imports this helper does not try to collect it.
test.__test__ = False

def run_train(*, epochs=5, batch_size=100, lr=0.001, train_fraction=0.90):
    """Build the dataset, model and optimizer, then train and evaluate.

    A ``StoDataset`` is split at random into a training part and a test part.
    A ``NeuralNetwork`` is then optimised with Adam against an MSE loss, and
    ``train`` followed by ``test`` is called once per epoch.

    Args:
        epochs: Number of train/test rounds. Defaults to 5.
        batch_size: Batch size used for both data loaders. Defaults to 100.
        lr: Learning rate passed to ``optim.Adam``. Defaults to 0.001.
        train_fraction: Share of the dataset used for training; the rest is
            used for testing. Defaults to 0.90.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0, 1], got {train_fraction}"
        )

    dataset_object = StoDataset()
    total = len(dataset_object)
    # Truncate the training share; the test split takes whatever is left so
    # the two sizes always add up to the full dataset.
    train_data_size = int(total * train_fraction)
    test_data_size = total - train_data_size
    train_data, test_data = random_split(
        dataset_object, [train_data_size, test_data_size]
    )

    # NOTE(review): the training loader is created without ``shuffle=True``,
    # so batches are visited in the same order every epoch - confirm this is
    # intended.
    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    model = NeuralNetwork()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_data_loader, model, loss_fn, optimizer)
        test(test_data_loader, model, loss_fn)
    print("Done!")

# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    run_train()
